{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flax seq2seq Example\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb\" ><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "Demonstration notebook for\n",
    "https://github.com/google/flax/tree/main/examples/seq2seq\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **Flax Notebook Workflow**:\n",
    "\n",
    "1. Run the entire notebook end-to-end and check out the outputs.\n",
    "   - This will open Python files in the right-hand editor!\n",
    "   - You'll be able to interactively explore metrics in TensorBoard.\n",
    "2. Change some of the hyperparameters via the command-line flags defined in `train.py`, re-run training, and compare the updated TensorBoard plots.\n",
    "3. Update the code in `train.py`, `models.py`, and `input_pipeline.py`.\n",
    "   Thanks to `%autoreload`, any changes you make to these files will\n",
    "   automatically appear in the notebook. Some ideas to get you started:\n",
    "   - Change the model.\n",
    "   - Log some per-batch metrics during training.\n",
    "   - Add new hyperparameters to `models.py` and use them in `train.py`.\n",
    "   - Train on a different vocabulary by initializing `CharacterTable` with a\n",
    "     different character set.\n",
    "4. At any time, feel free to paste code from the source code into the notebook\n",
    "   and modify it directly there!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "4c0a705c-8d7e-44cc-d851-873a40ac115e"
   },
   "outputs": [],
   "source": [
    "# Install CLU & Flax.\n",
    "!pip install -q clu flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "example_directory = 'examples/seq2seq'\n",
    "editor_relpaths = ('train.py', 'input_pipeline.py', 'models.py')\n",
    "\n",
    "repo, branch = 'https://github.com/google/flax', 'main'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "outputId": "4801432e-4090-4b13-f0f2-d99a3039ce47"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'flaxrepo'...\n",
      "remote: Enumerating objects: 349, done.\n",
      "remote: Counting objects: 100% (349/349), done.\n",
      "remote: Compressing objects: 100% (286/286), done.\n",
      "remote: Total 349 (delta 63), reused 220 (delta 51), pack-reused 0\n",
      "Receiving objects: 100% (349/349), 2.12 MiB | 13.39 MiB/s, done.\n",
      "Resolving deltas: 100% (63/63), done.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<h1 style=\"color:red;\" class=\"blink\">WARNING : Editing in VM - changes lost after reboot!!</h1>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "      ((filepath) => {{\n",
       "        if (!google.colab.kernel.accessAllowed) {{\n",
       "          return;\n",
       "        }}\n",
       "        google.colab.files.view(filepath);\n",
       "      }})(\"/content/examples/seq2seq/train.py\")"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "      ((filepath) => {{\n",
       "        if (!google.colab.kernel.accessAllowed) {{\n",
       "          return;\n",
       "        }}\n",
       "        google.colab.files.view(filepath);\n",
       "      }})(\"/content/examples/seq2seq/input_pipeline.py\")"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "      ((filepath) => {{\n",
       "        if (!google.colab.kernel.accessAllowed) {{\n",
       "          return;\n",
       "        }}\n",
       "        google.colab.files.view(filepath);\n",
       "      }})(\"/content/examples/seq2seq/models.py\")"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# (If you run this code in Jupyter[lab], then you're already in the\n",
    "#  example directory and nothing needs to be done.)\n",
    "\n",
    "#@markdown **Fetch newest Flax, copy example code**\n",
    "#@markdown\n",
    "#@markdown **If you select no** below, then the files will be stored on the\n",
    "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n",
    "#@markdown be restarted and any changes are lost**.\n",
    "#@markdown\n",
    "#@markdown **If you select yes** below, then you will be asked for your\n",
    "#@markdown credentials to mount your personal Google Drive. In this case, all\n",
    "#@markdown changes you make will be *persisted*, and even if you re-run the\n",
    "#@markdown Colab later on, the files will still be the same (you can of course\n",
    "#@markdown remove directories inside your Drive's `flax/` root if you want to\n",
    "#@markdown manually revert these files).\n",
    "\n",
    "if 'google.colab' in str(get_ipython()):\n",
    "  import os\n",
    "  os.chdir('/content')\n",
    "  # Download Flax repo from GitHub.\n",
    "  if not os.path.isdir('flaxrepo'):\n",
    "    !git clone --depth=1 -b $branch $repo flaxrepo\n",
    "  # Copy example files & change directory.\n",
    "  mount_gdrive = 'no' #@param ['yes', 'no']\n",
    "  if mount_gdrive == 'yes':\n",
    "    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/gdrive')\n",
    "    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n",
    "  else:\n",
    "    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'\n",
    "    example_root_path = f'/content/{example_directory}'\n",
    "    from IPython import display\n",
    "    display.display(display.HTML(\n",
    "        f'<h1 style=\"color:red;\" class=\"blink\">{DISCLAIMER}</h1>'))\n",
    "  if not os.path.isdir(example_root_path):\n",
    "    os.makedirs(example_root_path)\n",
    "    !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n",
    "  os.chdir(example_root_path)\n",
    "  from google.colab import files\n",
    "  for relpath in editor_relpaths:\n",
    "    path = f'{example_root_path}/{relpath}'\n",
    "    with open(path) as f:\n",
    "      s = f.read()\n",
    "    # Prepend the disclaimer as a header comment, then open the file in the editor.\n",
    "    with open(path, 'w') as f:\n",
    "      f.write(f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n",
    "    files.view(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "a292a7a2-ae3c-4518-af28-9c2fa0ed2d7b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/examples/seq2seq\n"
     ]
    }
   ],
   "source": [
    "# Note: in Colab, the cell above changed the working directory.\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from absl import app\n",
    "app.parse_flags_with_usage(['seq2seq'])\n",
    "\n",
    "from absl import logging\n",
    "logging.set_verbosity(logging.INFO)\n",
    "\n",
    "import jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "7e1a29ce-9d8b-4715-ce60-9eae100a1df3",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Local imports from current directory - auto reload.\n",
    "# Any changes you make to the three imported files will appear automatically.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import input_pipeline\n",
    "import models\n",
    "import train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "cb5f7f6e-1e6f-40ff-e0d6-5b428511d75b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('72+789', '=861'),\n",
       " ('58+858', '=916'),\n",
       " ('77+358', '=435'),\n",
       " ('99+264', '=363'),\n",
       " ('94+115', '=209')]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Examples are generated on the fly.\n",
    "ctable = input_pipeline.CharacterTable('0123456789+= ')\n",
    "list(ctable.generate_examples(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "b58ea813-e757-4cc5-f3ba-3cb0f05d35a6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = ctable.get_batch(5)\n",
    "# Each query (and each answer) is one-hot encoded.\n",
    "batch['query'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "3b33e061-f0b5-42d7-ad49-5058e8fd3b90"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['1+243'], dtype='<U5')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Note how `ctable` encodes PAD=0, EOS=1, '0'=2, '1'=3, ...\n",
    "ctable.decode_onehot(batch['query'][:1])"
   ]
  },
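  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-hot layout can be reproduced with a few lines of NumPy. The cell below is a minimal sketch, not the code in `input_pipeline.py`: the names `encode_onehot_sketch` and `decode_onehot_sketch` are hypothetical, and the id convention (PAD=0, EOS=1, then the characters in order) and the `max_len=8` default are assumptions based on the comments and shapes shown in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Assumed id convention: PAD=0, EOS=1, then the characters in order.\n",
    "CHARS = '0123456789+= '\n",
    "PAD_ID, EOS_ID = 0, 1\n",
    "CHAR_TO_ID = {ch: i + 2 for i, ch in enumerate(CHARS)}\n",
    "ID_TO_CHAR = {i: ch for ch, i in CHAR_TO_ID.items()}\n",
    "VOCAB_SIZE = len(CHARS) + 2  # 13 characters + PAD + EOS = 15\n",
    "\n",
    "def encode_onehot_sketch(text, max_len=8):\n",
    "  # Map characters to ids, append EOS, then right-pad to a fixed length.\n",
    "  ids = [CHAR_TO_ID[ch] for ch in text] + [EOS_ID]\n",
    "  if len(ids) > max_len:\n",
    "    raise ValueError(f'{text!r} does not fit in {max_len} tokens.')\n",
    "  ids += [PAD_ID] * (max_len - len(ids))\n",
    "  # Row i of the identity matrix is the one-hot vector for id i.\n",
    "  return np.eye(VOCAB_SIZE, dtype=np.float32)[ids]\n",
    "\n",
    "def decode_onehot_sketch(onehot):\n",
    "  # Take the argmax per position and stop at the first EOS.\n",
    "  chars = []\n",
    "  for token_id in onehot.argmax(axis=-1):\n",
    "    if token_id == EOS_ID:\n",
    "      break\n",
    "    chars.append(ID_TO_CHAR[int(token_id)])\n",
    "  return ''.join(chars)\n",
    "\n",
    "encoded = encode_onehot_sketch('2+40')\n",
    "assert encoded.shape == (8, VOCAB_SIZE)\n",
    "assert decode_onehot_sketch(encoded) == '2+40'"
   ]
  },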
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a live update during training - use the \"refresh\" button!\n",
    "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n",
    "if 'google.colab' in str(get_ipython()):\n",
    "  %load_ext tensorboard\n",
    "  %tensorboard --logdir=./workdirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "workdir = f'./workdirs/{int(time.time())}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "e49554e2-9336-4d97-a1e2-82b9e98407da"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['seq2seq']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for 2,000 steps, decoding and logging every 100 steps.\n",
    "app.parse_flags_with_usage([\n",
    "    'seq2seq',\n",
    "    '--num_train_steps=2000',\n",
    "    '--decode_frequency=100',\n",
    "])"
   ]
  },
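  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the cell above returns `['seq2seq']`, here is a small sketch of absl flag parsing. It uses a private `flags.FlagValues()` instance so it cannot clash with the flags that `train.py` registers on the global `FLAGS`. The flag names mirror the ones passed above, but the default values and help strings are hypothetical, not the ones defined in `train.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from absl import flags\n",
    "\n",
    "# A private FlagValues instance, separate from the global absl FLAGS.\n",
    "sketch_flags = flags.FlagValues()\n",
    "# Hypothetical defaults and help strings, for illustration only.\n",
    "flags.DEFINE_integer('num_train_steps', 10000, 'Number of training steps.',\n",
    "                     flag_values=sketch_flags)\n",
    "flags.DEFINE_integer('decode_frequency', 200, 'Decode every N steps.',\n",
    "                     flag_values=sketch_flags)\n",
    "\n",
    "# Parsing consumes the --flag arguments and returns the leftover ones,\n",
    "# starting with the program name.\n",
    "remaining = sketch_flags(\n",
    "    ['seq2seq', '--num_train_steps=2000', '--decode_frequency=100'])\n",
    "assert remaining == ['seq2seq']\n",
    "assert sketch_flags.num_train_steps == 2000\n",
    "assert sketch_flags.decode_frequency == 100"
   ]
  },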
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "49396889-35b0-4a11-8b8a-e67624be32a7"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:[100] accuracy=0.015625, loss=0.6936355233192444\n",
      "INFO:absl:DECODE: 96+964 = 1002 (INCORRECT) correct=1060\n",
      "INFO:absl:DECODE: 71+545 = 608 (INCORRECT) correct=616\n",
      "INFO:absl:DECODE: 42+729 = 730 (INCORRECT) correct=771\n",
      "INFO:absl:DECODE: 28+588 = 684 (INCORRECT) correct=616\n",
      "INFO:absl:DECODE: 39+648 = 618 (INCORRECT) correct=687\n",
      "INFO:absl:[200] accuracy=0.03125, loss=0.5422528982162476\n",
      "INFO:absl:DECODE: 18+70 = 70 (INCORRECT) correct=88\n",
      "INFO:absl:DECODE: 43+123 = 145 (INCORRECT) correct=166\n",
      "INFO:absl:DECODE: 72+406 = 460 (INCORRECT) correct=478\n",
      "INFO:absl:DECODE: 53+443 = 492 (INCORRECT) correct=496\n",
      "INFO:absl:DECODE: 74+844 = 936 (INCORRECT) correct=918\n",
      "INFO:absl:[300] accuracy=0.0703125, loss=0.462927907705307\n",
      "INFO:absl:DECODE: 40+598 = 643 (INCORRECT) correct=638\n",
      "INFO:absl:DECODE: 2+72 = 75 (INCORRECT) correct=74\n",
      "INFO:absl:DECODE: 70+742 = 814 (INCORRECT) correct=812\n",
      "INFO:absl:DECODE: 22+943 = 963 (INCORRECT) correct=965\n",
      "INFO:absl:DECODE: 85+890 = 975 (CORRECT)\n",
      "INFO:absl:[400] accuracy=0.0390625, loss=0.4398380219936371\n",
      "INFO:absl:DECODE: 79+562 = 647 (INCORRECT) correct=641\n",
      "INFO:absl:DECODE: 47+987 = 1043 (INCORRECT) correct=1034\n",
      "INFO:absl:DECODE: 92+959 = 1064 (INCORRECT) correct=1051\n",
      "INFO:absl:DECODE: 38+291 = 334 (INCORRECT) correct=329\n",
      "INFO:absl:DECODE: 95+476 = 571 (CORRECT)\n",
      "INFO:absl:[500] accuracy=0.0703125, loss=0.4079001545906067\n",
      "INFO:absl:DECODE: 9+871 = 872 (INCORRECT) correct=880\n",
      "INFO:absl:DECODE: 58+633 = 696 (INCORRECT) correct=691\n",
      "INFO:absl:DECODE: 6+98 = 113 (INCORRECT) correct=104\n",
      "INFO:absl:DECODE: 8+325 = 345 (INCORRECT) correct=333\n",
      "INFO:absl:DECODE: 43+540 = 580 (INCORRECT) correct=583\n",
      "INFO:absl:[600] accuracy=0.0390625, loss=0.4265238344669342\n",
      "INFO:absl:DECODE: 67+554 = 613 (INCORRECT) correct=621\n",
      "INFO:absl:DECODE: 38+189 = 233 (INCORRECT) correct=227\n",
      "INFO:absl:DECODE: 28+843 = 878 (INCORRECT) correct=871\n",
      "INFO:absl:DECODE: 11+948 = 969 (INCORRECT) correct=959\n",
      "INFO:absl:DECODE: 9+121 = 148 (INCORRECT) correct=130\n",
      "INFO:absl:[700] accuracy=0.078125, loss=0.38418614864349365\n",
      "INFO:absl:DECODE: 38+528 = 562 (INCORRECT) correct=566\n",
      "INFO:absl:DECODE: 57+317 = 372 (INCORRECT) correct=374\n",
      "INFO:absl:DECODE: 64+246 = 300 (INCORRECT) correct=310\n",
      "INFO:absl:DECODE: 73+296 = 373 (INCORRECT) correct=369\n",
      "INFO:absl:DECODE: 14+321 = 336 (INCORRECT) correct=335\n",
      "INFO:absl:[800] accuracy=0.1171875, loss=0.3908345401287079\n",
      "INFO:absl:DECODE: 26+902 = 931 (INCORRECT) correct=928\n",
      "INFO:absl:DECODE: 15+818 = 843 (INCORRECT) correct=833\n",
      "INFO:absl:DECODE: 19+348 = 359 (INCORRECT) correct=367\n",
      "INFO:absl:DECODE: 79+878 = 959 (INCORRECT) correct=957\n",
      "INFO:absl:DECODE: 64+824 = 889 (INCORRECT) correct=888\n",
      "INFO:absl:[900] accuracy=0.203125, loss=0.3021599352359772\n",
      "INFO:absl:DECODE: 6+477 = 482 (INCORRECT) correct=483\n",
      "INFO:absl:DECODE: 11+18 = 39 (INCORRECT) correct=29\n",
      "INFO:absl:DECODE: 35+777 = 814 (INCORRECT) correct=812\n",
      "INFO:absl:DECODE: 45+156 = 200 (INCORRECT) correct=201\n",
      "INFO:absl:DECODE: 11+432 = 445 (INCORRECT) correct=443\n",
      "INFO:absl:[1000] accuracy=0.6015625, loss=0.16906459629535675\n",
      "INFO:absl:DECODE: 41+156 = 198 (INCORRECT) correct=197\n",
      "INFO:absl:DECODE: 45+177 = 222 (CORRECT)\n",
      "INFO:absl:DECODE: 94+845 = 937 (INCORRECT) correct=939\n",
      "INFO:absl:DECODE: 10+73 = 93 (INCORRECT) correct=83\n",
      "INFO:absl:DECODE: 62+598 = 661 (INCORRECT) correct=660\n",
      "INFO:absl:[1100] accuracy=0.6796875, loss=0.12802045047283173\n",
      "INFO:absl:DECODE: 95+134 = 218 (INCORRECT) correct=229\n",
      "INFO:absl:DECODE: 69+508 = 577 (CORRECT)\n",
      "INFO:absl:DECODE: 74+636 = 710 (CORRECT)\n",
      "INFO:absl:DECODE: 11+522 = 533 (CORRECT)\n",
      "INFO:absl:DECODE: 0+620 = 620 (CORRECT)\n",
      "INFO:absl:[1200] accuracy=0.96875, loss=0.027045272290706635\n",
      "INFO:absl:DECODE: 93+376 = 469 (CORRECT)\n",
      "INFO:absl:DECODE: 65+474 = 549 (INCORRECT) correct=539\n",
      "INFO:absl:DECODE: 33+35 = 58 (INCORRECT) correct=68\n",
      "INFO:absl:DECODE: 87+527 = 615 (INCORRECT) correct=614\n",
      "INFO:absl:DECODE: 40+77 = 107 (INCORRECT) correct=117\n",
      "INFO:absl:[1300] accuracy=0.9453125, loss=0.025903644040226936\n",
      "INFO:absl:DECODE: 9+341 = 350 (CORRECT)\n",
      "INFO:absl:DECODE: 29+59 = 88 (CORRECT)\n",
      "INFO:absl:DECODE: 83+473 = 556 (CORRECT)\n",
      "INFO:absl:DECODE: 51+339 = 390 (CORRECT)\n",
      "INFO:absl:DECODE: 29+63 = 92 (CORRECT)\n",
      "INFO:absl:[1400] accuracy=0.890625, loss=0.03969202935695648\n",
      "INFO:absl:DECODE: 77+177 = 254 (CORRECT)\n",
      "INFO:absl:DECODE: 66+692 = 759 (INCORRECT) correct=758\n",
      "INFO:absl:DECODE: 24+321 = 345 (CORRECT)\n",
      "INFO:absl:DECODE: 23+180 = 203 (CORRECT)\n",
      "INFO:absl:DECODE: 67+898 = 965 (CORRECT)\n",
      "INFO:absl:[1500] accuracy=0.984375, loss=0.009904923848807812\n",
      "INFO:absl:DECODE: 13+119 = 132 (CORRECT)\n",
      "INFO:absl:DECODE: 34+467 = 501 (CORRECT)\n",
      "INFO:absl:DECODE: 20+661 = 681 (CORRECT)\n",
      "INFO:absl:DECODE: 51+779 = 830 (CORRECT)\n",
      "INFO:absl:DECODE: 12+101 = 113 (CORRECT)\n",
      "INFO:absl:[1600] accuracy=0.9375, loss=0.021177640184760094\n",
      "INFO:absl:DECODE: 61+275 = 336 (CORRECT)\n",
      "INFO:absl:DECODE: 48+831 = 879 (CORRECT)\n",
      "INFO:absl:DECODE: 25+628 = 653 (CORRECT)\n",
      "INFO:absl:DECODE: 87+933 = 1020 (CORRECT)\n",
      "INFO:absl:DECODE: 75+405 = 480 (CORRECT)\n",
      "INFO:absl:[1700] accuracy=0.96875, loss=0.015190182253718376\n",
      "INFO:absl:DECODE: 83+347 = 430 (CORRECT)\n",
      "INFO:absl:DECODE: 45+832 = 877 (CORRECT)\n",
      "INFO:absl:DECODE: 47+867 = 914 (CORRECT)\n",
      "INFO:absl:DECODE: 39+450 = 489 (CORRECT)\n",
      "INFO:absl:DECODE: 2+357 = 359 (CORRECT)\n",
      "INFO:absl:[1800] accuracy=0.9921875, loss=0.008457459509372711\n",
      "INFO:absl:DECODE: 64+97 = 161 (CORRECT)\n",
      "INFO:absl:DECODE: 66+335 = 401 (CORRECT)\n",
      "INFO:absl:DECODE: 15+656 = 671 (CORRECT)\n",
      "INFO:absl:DECODE: 40+689 = 729 (CORRECT)\n",
      "INFO:absl:DECODE: 42+104 = 146 (CORRECT)\n",
      "INFO:absl:[1900] accuracy=0.90625, loss=0.026923568919301033\n",
      "INFO:absl:DECODE: 34+92 = 116 (INCORRECT) correct=126\n",
      "INFO:absl:DECODE: 52+466 = 518 (CORRECT)\n",
      "INFO:absl:DECODE: 46+47 = 93 (CORRECT)\n",
      "INFO:absl:DECODE: 61+240 = 301 (CORRECT)\n",
      "INFO:absl:DECODE: 86+951 = 1037 (CORRECT)\n"
     ]
    }
   ],
   "source": [
    "state = train.train_and_evaluate(workdir=workdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "outputId": "2beaf4e9-b10b-4156-d2d9-187777306de0",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "***** TensorBoard Uploader *****\n",
      "\n",
      "This will upload your TensorBoard logs to https://tensorboard.dev/ from\n",
      "the following directory:\n",
      "\n",
      "./workdirs\n",
      "\n",
      "This TensorBoard will be visible to everyone. Do not upload sensitive\n",
      "data.\n",
      "\n",
      "Your use of this service is subject to Google's Terms of Service\n",
      "<https://policies.google.com/terms> and Privacy Policy\n",
      "<https://policies.google.com/privacy>, and TensorBoard.dev's Terms of Service\n",
      "<https://tensorboard.dev/policy/terms/>.\n",
      "\n",
      "This notice will not be shown again while you are logged into the uploader.\n",
      "To log out, run `tensorboard dev auth revoke`.\n",
      "\n",
      "Continue? (yes/NO) yes\n",
      "\n",
      "\n",
      "\n",
      "New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/pgfmdFlaQTy9odov72ZvVQ/\n",
      "\n",
      "[2022-02-25T09:34:21] Started scanning logdir.\n",
      "[2022-02-25T09:34:22] Total uploaded: 38 scalars, 0 tensors, 0 binary objects\n",
      "[2022-02-25T09:34:22] Done scanning logdir.\n",
      "\n",
      "\n",
      "Done. View your TensorBoard at https://tensorboard.dev/experiment/pgfmdFlaQTy9odov72ZvVQ/\n"
     ]
    }
   ],
   "source": [
    "if 'google.colab' in str(get_ipython()):\n",
    "  #@markdown You can upload the training results directly to https://tensorboard.dev\n",
    "  #@markdown\n",
    "  #@markdown Note that everybody with the link will be able to see the data.\n",
    "  upload_data = 'yes' #@param ['yes', 'no']\n",
    "  if upload_data == 'yes':\n",
    "    !tensorboard dev upload --one_shot --logdir ./workdirs --name 'Flax examples/seq2seq (Colab)'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "e22b7208-5413-4a63-abfb-b510af60f340"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 8, 15)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = ctable.encode_onehot(['2+40'])\n",
    "# batch, max_length, vocab_size\n",
    "inputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using different random seeds generates different samples.\n",
    "preds = train.decode(state.params, inputs, jax.random.key(0), ctable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "outputId": "e5cdfd75-2c66-4165-8ab7-9fdecde5062a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['42'], dtype='<U2')"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctable.decode_onehot(preds)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
